{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b8d7b17-af50-42cd-b531-ef61c49c9e61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the work directory to the imaginaire root.\n",
    "import os, sys, time\n",
    "import pathlib\n",
    "root_dir = pathlib.Path().absolute().parents[2]\n",
    "os.chdir(root_dir)\n",
    "print(f\"Root Directory Path: {root_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5b9e2f-841c-4815-92e0-0c76ed46da62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import Python libraries.\n",
    "import numpy as np\n",
    "import torch\n",
    "import k3d\n",
    "import json\n",
    "import plotly.graph_objs as go\n",
    "from collections import OrderedDict\n",
    "# Import imaginaire modules.\n",
    "from projects.nerf.utils import camera, visualize\n",
    "from third_party.colmap.scripts.python.read_write_model import read_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76033016-2d92-4a5d-9e50-3978553e8df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the COLMAP data.\n",
    "colmap_path = \"datasets/lego_ds2\"\n",
    "cameras, images, points_3D = read_model(path=f\"{colmap_path}/sparse\", ext=\".bin\")\n",
    "# Convert camera poses.\n",
    "images = OrderedDict(sorted(images.items()))\n",
    "qvecs = torch.from_numpy(np.stack([image.qvec for image in images.values()]))\n",
    "tvecs = torch.from_numpy(np.stack([image.tvec for image in images.values()]))\n",
    "Rs = camera.quaternion.q_to_R(qvecs)\n",
    "poses = torch.cat([Rs, tvecs[..., None]], dim=-1)  # [N,3,4]\n",
    "print(f\"# images: {len(poses)}\")\n",
    "# Get the sparse 3D points and the colors.\n",
    "xyzs = torch.from_numpy(np.stack([point.xyz for point in points_3D.values()]))\n",
    "rgbs = np.stack([point.rgb for point in points_3D.values()])\n",
    "rgbs_int32 = (rgbs[:, 0] * 2**16 + rgbs[:, 1] * 2**8 + rgbs[:, 2]).astype(np.uint32)\n",
    "print(f\"# points: {len(xyzs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47862ee1-286c-4877-a181-4b33b7733719",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_depth = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6cf60ec-fe6a-43ba-9aaf-e3c7afd88208",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the bounding sphere.\n",
    "json_fname = f\"{colmap_path}/transforms.json\"\n",
    "with open(json_fname) as file:\n",
    "    meta = json.load(file)\n",
    "center = meta[\"sphere_center\"]\n",
    "radius = meta[\"sphere_radius\"]\n",
    "# ------------------------------------------------------------------------------------\n",
    "# These variables can be adjusted to make the bounding sphere fit the region of interest.\n",
    "# The adjusted values can then be set in the config as data.readjust.center and data.readjust.scale\n",
    "readjust_center = np.array([0., 0., 0.])\n",
    "readjust_scale = 1.\n",
    "# ------------------------------------------------------------------------------------\n",
    "center += readjust_center\n",
    "radius *= readjust_scale\n",
    "# Make some points to hallucinate a bounding sphere.\n",
    "sphere_points = np.random.randn(100000, 3)\n",
    "sphere_points = sphere_points / np.linalg.norm(sphere_points, axis=-1, keepdims=True)\n",
    "sphere_points = sphere_points * radius + center"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e986aed0-1aaf-4772-937c-136db7f2eaec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can choose to visualize with Plotly...\n",
    "x, y, z = *xyzs.T,\n",
    "colors = rgbs / 255.0\n",
    "sphere_x, sphere_y, sphere_z = *sphere_points.T,\n",
    "sphere_colors = [\"#4488ff\"] * len(sphere_points)\n",
    "traces_poses = visualize.plotly_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.05)\n",
    "trace_points = go.Scatter3d(x=x, y=y, z=z, mode=\"markers\", marker=dict(size=1, color=colors, opacity=1), hoverinfo=\"skip\")\n",
    "trace_sphere = go.Scatter3d(x=sphere_x, y=sphere_y, z=sphere_z, mode=\"markers\", marker=dict(size=0.5, color=sphere_colors, opacity=0.7), hoverinfo=\"skip\")\n",
    "traces_all = traces_poses + [trace_points, trace_sphere]\n",
    "layout = go.Layout(scene=dict(xaxis=dict(showspikes=False, backgroundcolor=\"rgba(0,0,0,0)\", gridcolor=\"rgba(0,0,0,0.1)\"),\n",
    "                              yaxis=dict(showspikes=False, backgroundcolor=\"rgba(0,0,0,0)\", gridcolor=\"rgba(0,0,0,0.1)\"),\n",
    "                              zaxis=dict(showspikes=False, backgroundcolor=\"rgba(0,0,0,0)\", gridcolor=\"rgba(0,0,0,0.1)\"),\n",
    "                              xaxis_title=\"X\", yaxis_title=\"Y\", zaxis_title=\"Z\", dragmode=\"orbit\",\n",
    "                              aspectratio=dict(x=1, y=1, z=1), aspectmode=\"data\"), height=800)\n",
    "fig = go.Figure(data=traces_all, layout=layout)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdde170b-4546-4617-9162-a9fcb936347d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ... or visualize with K3D.\n",
    "plot = k3d.plot(name=\"poses\", height=800, camera_rotate_speed=5.0, camera_zoom_speed=3.0, camera_pan_speed=1.0)\n",
    "k3d_objects = visualize.k3d_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.05)\n",
    "for k3d_object in k3d_objects:\n",
    "    plot += k3d_object\n",
    "plot += k3d.points(xyzs, colors=rgbs_int32, point_size=0.02, shader=\"flat\")\n",
    "plot += k3d.points(sphere_points, color=0x4488ff, point_size=0.01, shader=\"flat\")\n",
    "plot.display()\n",
    "plot.camera_fov = 30.0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
